{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.linear_model import LinearRegression\n",
    "from sklearn import datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the Boston housing dataset bundled with scikit-learn.\n",
    "# NOTE(review): load_boston was deprecated in scikit-learn 1.0 and removed in 1.2,\n",
    "# so this cell requires an older scikit-learn release.\n",
    "boston_housing = datasets.load_boston()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the feature names of the loaded dataset\n",
    "boston_housing.feature_names"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hold out 20% of the samples for evaluation and fit a linear regression on the rest\n",
    "from sklearn.model_selection import train_test_split\n",
    "X_train, X_test, y_train, y_test = train_test_split(\n",
    "    boston_housing.data,\n",
    "    boston_housing.target,\n",
    "    test_size=0.2,\n",
    "    random_state=0,  # fixed seed keeps the split reproducible\n",
    ")\n",
    "model = LinearRegression().fit(X_train, y_train)\n",
    "# Last expression: score on the held-out split (R^2 for scikit-learn regressors)\n",
    "model.score(X_test, y_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the fitted model with joblib\n",
    "import os\n",
    "from joblib import dump, load\n",
    "model_path = '../models/boston_housing.joblib'\n",
    "# Create the target directory first so the cell also works on a fresh checkout\n",
    "# where ../models does not exist yet (dump would raise FileNotFoundError otherwise)\n",
    "os.makedirs(os.path.dirname(model_path), exist_ok=True)\n",
    "dump(model, model_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert the fitted model into ONNX format with skl2onnx\n",
    "from skl2onnx import convert_sklearn\n",
    "from skl2onnx.common.data_types import FloatTensorType\n",
    "# Leave the batch dimension unspecified (None) so the exported graph accepts any\n",
    "# number of rows; a fixed batch size of 1 does not match the full test matrix that\n",
    "# the inference cell passes to ONNX Runtime in one call.\n",
    "n_features = X_train.shape[1]\n",
    "initial_type = [('float_input', FloatTensorType([None, n_features]))]\n",
    "onx = convert_sklearn(model, initial_types=initial_type)\n",
    "# NOTE: the file name (spelling included) must match the one opened by the inference cell\n",
    "with open(\"boston_housting.onnx\", \"wb\") as f:\n",
    "    f.write(onx.SerializeToString())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compute the prediction with ONNX Runtime and compare it with scikit-learn\n",
    "# (numpy is already imported as np in the first cell)\n",
    "import onnxruntime as rt\n",
    "sess = rt.InferenceSession(\"boston_housting.onnx\")\n",
    "input_name = sess.get_inputs()[0].name\n",
    "label_name = sess.get_outputs()[0].name\n",
    "# The model was exported with a FloatTensorType input, so cast the test matrix to float32\n",
    "pred_onx = sess.run([label_name], {input_name: X_test.astype(np.float32)})[0]\n",
    "# Largest absolute gap between the ONNX and scikit-learn predictions (expected to be tiny)\n",
    "np.abs(pred_onx.ravel() - model.predict(X_test)).max()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
